# src/pointer/translator_cc.py
# Compact-curvature translator (LoG/DoG → S+ → Poisson/optical → radial & lensing fits)

from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional
import numpy as np
from numpy.fft import rfft2, irfft2, fft2, ifft2  # kept for compatibility
from scipy.ndimage import (
    gaussian_filter, gaussian_laplace, binary_opening, binary_closing,
    binary_fill_holes, label, generate_binary_structure
)

# ------------------------------
# Utilities
# ------------------------------

def _npz_load_e0(path: str) -> np.ndarray:
    """Load E0 from NPZ robustly: use 'E0' key if present, else first array."""
    d = np.load(path)
    return d["E0"] if "E0" in d.files else d[d.files[0]]

def _zscore(x: np.ndarray) -> np.ndarray:
    m, s = np.mean(x), np.std(x)
    return (x - m) / (s + 1e-12)

def _parse_threshold(spec: str, arr: np.ndarray) -> float:
    """
    'quantile:0.90' -> top-10% cutoff, 'median', '75pct', 'fixed:1.23'
    """
    s = spec.strip().lower()
    if s.startswith("quantile:"):
        q = float(s.split(":")[1])
        return float(np.quantile(arr, q))
    if s == "median":
        return float(np.median(arr))
    if s.endswith("pct"):
        q = float(s.replace("pct",""))/100.0
        return float(np.quantile(arr, q))
    if s.startswith("fixed:"):
        return float(s.split(":")[1])
    # default: 90th pct
    return float(np.quantile(arr, 0.90))

def _structure_for_connectivity(conn: int) -> np.ndarray:
    # 2D: 4/8 connectivity via generate_binary_structure
    if conn == 4:
        return generate_binary_structure(2, 1)
    return generate_binary_structure(2, 2)

def _remove_small_components(mask: np.ndarray, min_pixels: int) -> np.ndarray:
    """Drop 8-connected components of ``mask`` that have fewer than ``min_pixels`` pixels.

    The input is returned unchanged when ``min_pixels <= 1`` or when the mask
    has no foreground component; otherwise a boolean array of the same shape
    is returned.
    """
    if min_pixels <= 1:
        return mask
    labels, n_components = label(mask, structure=_structure_for_connectivity(8))
    if n_components == 0:
        return mask
    sizes = np.bincount(labels.ravel())
    # Per-label lookup table: True for labels that survive. Label 0 is the
    # background and is never kept.
    keep_label = sizes >= min_pixels
    keep_label[0] = False
    return keep_label[labels]

def _largest_component(mask: np.ndarray) -> np.ndarray:
    """Return a boolean mask selecting the largest 8-connected component of ``mask``.

    An all-False array is returned when ``mask`` has no foreground component.
    """
    labels, n_components = label(mask, structure=_structure_for_connectivity(8))
    if n_components == 0:
        return np.zeros_like(mask, dtype=bool)
    # Skip entry 0 (background) when looking for the biggest label.
    foreground_sizes = np.bincount(labels.ravel())[1:]
    winner = np.argmax(foreground_sizes) + 1
    return labels == winner

def _next_pow2(n: int) -> int:
    return 1 << (int(n - 1).bit_length())

def _softened_kernel(L: int, eps: float) -> np.ndarray:
    """
    Build (2L-1)x(2L-1) kernel K[i,j] = 1/sqrt((i-(L-1))^2 + (j-(L-1))^2 + eps^2).
    """
    N = 2*L - 1
    y = np.arange(N) - (L - 1)
    x = np.arange(N) - (L - 1)
    X, Y = np.meshgrid(x, y, indexing='xy')
    K = 1.0 / np.sqrt((X*X + Y*Y).astype(np.float64) + (eps*eps))
    return K

def _fft_convolve2d_aperiodic(a: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    Linear (aperiodic) convolution via FFT with zero-padding; returns LxL crop.
    a: LxL, k: (2L-1)x(2L-1). Output: LxL.
    """
    Lx, Ly = a.shape
    kx, ky = k.shape
    Nx = _next_pow2(Lx + kx - 1)
    Ny = _next_pow2(Ly + ky - 1)
    A = np.zeros((Nx, Ny), dtype=np.float64)
    K = np.zeros((Nx, Ny), dtype=np.float64)
    A[:Lx, :Ly] = a
    K[:kx, :ky] = k
    F = np.fft.rfftn(A)
    G = np.fft.rfftn(K)
    C = np.fft.irfftn(F * G, s=(Nx, Ny))
    # crop out the valid region
    ox = kx - 1
    oy = ky - 1
    return C[ox:ox+Lx, oy:oy+Ly].copy()

def _gradient_mag(V: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(V)
    return np.sqrt(gx*gx + gy*gy)

def _radial_profile(arr: np.ndarray, nbins: int, scheme: str, rmin: float, rmax: float
                   ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return (r_centers, values_mean, counts) for arr, using bins in [rmin, rmax].
    Center at array center.
    """
    H, W = arr.shape
    cy, cx = (H-1)/2.0, (W-1)/2.0
    y = np.arange(H) - cy
    x = np.arange(W) - cx
    X, Y = np.meshgrid(x, y, indexing='xy')
    R = np.sqrt(X*X + Y*Y)

    mask = (R >= rmin) & (R <= rmax) & np.isfinite(arr)
    rvals = R[mask].ravel()
    avals = arr[mask].ravel()

    if scheme == "log":
        # avoid zero in edges
        rmin_eff = max(rmin, 1.0)
        edges = np.geomspace(rmin_eff, rmax, nbins+1)
    else:
        edges = np.linspace(rmin, rmax, nbins+1)

    idx = np.digitize(rvals, edges) - 1
    valid = (idx >= 0) & (idx < nbins)
    idx = idx[valid]
    rvals = rvals[valid]; avals = avals[valid]

    counts = np.bincount(idx, minlength=nbins).astype(np.int64)
    sums = np.bincount(idx, weights=avals, minlength=nbins)
    means = np.zeros(nbins, dtype=np.float64)
    nz = counts > 0
    means[nz] = sums[nz] / counts[nz]

    centers = 0.5*(edges[:-1] + edges[1:])
    return centers, means, counts

def _fit_powerlaw(x: np.ndarray, y: np.ndarray, weights: Optional[np.ndarray]=None
                 ) -> Tuple[float, float]:
    """
    Fit log(y) ~ a + s*log(x). Returns (slope s, R^2). Ignores nonpositive y.
    """
    mask = (x > 0) & (y > 0) & np.isfinite(x) & np.isfinite(y)
    x, y = x[mask], y[mask]
    if len(x) < 6:
        return np.nan, np.nan
    lx, ly = np.log(x), np.log(y)

    if weights is None:
        A = np.vstack([np.ones_like(lx), lx]).T
        coef, _, _, _ = np.linalg.lstsq(A, ly, rcond=None)
    else:
        w = np.asarray(weights)[mask]
        W = np.diag(w)
        A = np.vstack([np.ones_like(lx), lx]).T
        coef = np.linalg.lstsq(W @ A, W @ ly, rcond=None)[0]
    a, s = coef[0], coef[1]
    yhat = a + s*lx
    ss_res = np.sum((ly - yhat)**2)
    ss_tot = np.sum((ly - ly.mean())**2)
    r2 = 1.0 - ss_res/(ss_tot + 1e-12)
    return float(s), float(r2)

def _alpha_small_angle(n: np.ndarray, b_vals: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Small-angle deflection for index n(x,y) ~ 1 + lambda*S+(x,y).
    α(b) ≈ ∫ (∂⊥ n)/n dx. We take rays along x, impact parameter b along +y from center.
    Returns (alpha(b), 1/b).
    """
    H, W = n.shape
    cy, cx = H//2, W//2
    # gradient along y:
    dy, dx = np.gradient(n)
    alphas = []
    invb = []
    for b in b_vals:
        y_up = cy + int(round(b))
        y_dn = cy - int(round(b))
        if y_up < 0 or y_up >= H or y_dn < 0 or y_dn >= H:
            continue
        # integrate across x with trapezoid rule
        num_up = dy[y_up, :]
        den_up = np.maximum(n[y_up, :], 1e-12)
        num_dn = dy[y_dn, :]
        den_dn = np.maximum(n[y_dn, :], 1e-12)
        a_up = np.trapz(num_up/den_up, dx=1.0)
        a_dn = np.trapz(num_dn/den_dn, dx=1.0)
        a = 0.5*(np.abs(a_up) + np.abs(a_dn))
        alphas.append(a)
        invb.append(1.0/max(b, 1e-12))
    if not alphas:
        return np.array([]), np.array([])
    return np.asarray(alphas), np.asarray(invb)

# ------------------------------
# CC translator
# ------------------------------

@dataclass
class CCConfig:
    """Settings for the compact-curvature pipeline run by ``cc_translate_and_fit``.

    The fields cover three stages: building the S+ mask from the curvature
    response, the softened-kernel potential / optical index, and the radial
    and lensing fits.
    """
    # --- curvature response and S+ mask ---
    operator: str = "LoG"              # "LoG" (case-insensitive) uses -LoG; any other value uses DoG
    sigma_list: Tuple[int,...] = (1,2,3)  # filter scales; max() of this is the 'sigma' in 'k*sigma' specs
    normalize: str = "zscore"          # "zscore" (case-insensitive) z-scores the response; anything else leaves it raw
    threshold: str = "quantile:0.90"   # or "median" / "75pct" / "fixed:1.2"
    connectivity: int = 8              # 4 -> edge neighbours only; any other value -> 8-neighbourhood (opening/closing)
    morph_open: int = 1                # binary-opening iterations; 0 skips the step
    morph_close: int = 1               # binary-closing iterations; 0 skips the step
    fill_holes: bool = True            # fill holes in the mask after opening/closing
    remove_small_px: int = 50          # drop components with fewer pixels than this; values <= 1 disable it
    keep: str = "largest"              # "largest" keeps only the largest component; any other value keeps all

    # geometry / optics
    sheet3D: bool = True               # keep for future extensibility (not read by the code in this file)
    epsilon_soften: str = "0.5*sigma"  # kernel softening length; a number or 'k*sigma' (uses max(sigma_list))
    padding_factor: int = 0            # unused with exact aperiodic FFT; keep for completeness
    lambda_sweep: Tuple[float,...] = (0.2, 0.5, 1.0)  # lambda values for the index n = 1 + lambda*S+

    # radial & lensing fits
    radial_bins_scheme: str = "log"    # "log" -> geometric bin edges; any other value -> linear edges
    radial_bins: int = 48              # number of radial bins
    fit_window_min: str = "3*sigma"    # rmin (accepts k*sigma or a number)
    fit_window_max_fracL: float = 0.30 # rmax = this * L
    lensing_b_min: int = 8             # smallest impact parameter b (row offset from the centre row)
    lensing_b_max: int = 90            # largest impact parameter b
    lensing_b_n: int = 32              # number of b values, linearly spaced between min and max
    regression_weights: str = "counts" # "counts" weights radial fits by bin counts; anything else is unweighted

def _cc_response(E0: np.ndarray, cfg: CCConfig) -> np.ndarray:
    """Multi-scale LoG/DoG response; return positive curvature magnitude map.

    For ``cfg.operator == "LoG"`` (case-insensitive) the response is the
    pixel-wise maximum of -LoG over ``cfg.sigma_list``. For any other operator
    it is the pixel-wise maximum of the difference of Gaussians between
    adjacent sorted scales. The running maximum starts at zero, so the result
    is non-negative; with fewer than two scales the DoG response is all zeros.

    Returns a float64 array with E0's shape.
    """
    # scipy.ndimage filters allocate their output with the input's dtype by
    # default, so an integer image would get truncated (or, for unsigned
    # types, wrapped) responses. Work in float64 throughout.
    img = np.asarray(E0, dtype=np.float64)
    sigmas = [float(s) for s in cfg.sigma_list]
    resp = np.zeros_like(img)
    if cfg.operator.lower() == "log":
        for s in sigmas:
            # -LoG gives positive values for bright peaks
            resp = np.maximum(resp, -gaussian_laplace(img, sigma=s))
    else:
        # DoG: use adjacent scales; take positive part
        sigmas.sort()
        for s1, s2 in zip(sigmas, sigmas[1:]):
            resp = np.maximum(resp, gaussian_filter(img, s1) - gaussian_filter(img, s2))
    return resp

def _build_Splus(resp: np.ndarray, cfg: CCConfig) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Normalize → threshold → morphology → largest component as S+ mask.

    Steps:
      1. z-score ``resp`` when ``cfg.normalize`` is "zscore", else use it raw.
      2. Resolve ``cfg.threshold``. A number is used as-is. 'quantile:Q' is
         taken over the strictly positive values when there are at least 100
         of them, otherwise over the whole map. Other specs go through
         ``_parse_threshold``.
      3. Apply opening, closing, hole filling and small-component removal as
         configured.
      4. Keep only the largest component when ``cfg.keep`` is "largest"
         (case-insensitive), otherwise keep the whole mask.

    Returns (S+ as a float64 0/1 array, meta dict with the threshold value,
    mask coverage, pixel count and whether the positive tail was used).
    """
    if cfg.normalize.lower() == "zscore":
        rnorm = _zscore(resp)
    else:
        rnorm = resp.copy()

    thr_spec = cfg.threshold
    thr_from_pos_tail = False
    if isinstance(thr_spec, (int, float, np.integer, np.floating)):
        # A bare number is a fixed cutoff. Previously this fell through to
        # _parse_threshold, which calls .strip() and crashed on non-strings.
        thr = float(thr_spec)
    elif isinstance(thr_spec, str) and thr_spec.lower().startswith("quantile:"):
        # Threshold on the positive tail when using quantile (prevents flooding)
        q = float(thr_spec.split(":")[1])
        pos = rnorm[rnorm > 0]
        if pos.size >= 100:
            thr = float(np.quantile(pos, q))
            thr_from_pos_tail = True
        else:
            thr = float(np.quantile(rnorm, q))
    else:
        thr = _parse_threshold(thr_spec, rnorm)

    mask = rnorm >= thr

    st = _structure_for_connectivity(cfg.connectivity)
    if cfg.morph_open > 0:
        mask = binary_opening(mask, structure=st, iterations=cfg.morph_open)
    if cfg.morph_close > 0:
        mask = binary_closing(mask, structure=st, iterations=cfg.morph_close)
    if cfg.fill_holes:
        mask = binary_fill_holes(mask)
    if cfg.remove_small_px > 1:
        mask = _remove_small_components(mask, cfg.remove_small_px)

    # Case-insensitive, like the operator and normalize options.
    if str(cfg.keep).lower() == "largest":
        Splus = _largest_component(mask)
    else:
        Splus = mask

    meta = {
        "threshold_value": float(thr),
        "mask_coverage": float(Splus.mean()),
        "n_pixels": int(np.count_nonzero(Splus)),
        "threshold_from_positive_tail": bool(thr_from_pos_tail),
    }
    return Splus.astype(np.float64), meta

def _parse_eps(cfg: CCConfig) -> float:
    """Resolve ``cfg.epsilon_soften`` to a numeric softening length.

    Accepts a number, a numeric string, or a string of the form 'k*sigma',
    which evaluates to k * max(cfg.sigma_list). Any string whose numeric part
    cannot be parsed falls back to 1.0.
    """
    spec = cfg.epsilon_soften
    if isinstance(spec, (int, float)):
        return float(spec)
    text = str(spec).lower()
    head, marker, _ = text.partition("*sigma")
    try:
        factor = float(head if marker else text)
    except ValueError:
        # Same fallback for every malformed spec. Previously a bad 'k*sigma'
        # prefix raised while other bad strings returned 1.0.
        return 1.0
    # max() is kept outside the try so an empty sigma_list still raises
    # instead of being masked by the fallback.
    return factor * max(cfg.sigma_list) if marker else factor

def _parse_rmin(spec, sigma_max: float) -> float:
    """
    Accepts numbers (e.g., 40), or strings like '3*sigma', '5*sigma', '4.5*sigma'.
    Falls back to 3*sigma on parse errors.
    """
    if isinstance(spec, (int, float)):
        return float(spec)
    s = str(spec).strip().lower()
    if "*sigma" in s:
        try:
            k = float(s.split("*sigma")[0])
            return k * float(sigma_max)
        except Exception:
            return 3.0 * float(sigma_max)
    try:
        return float(s)
    except Exception:
        return 3.0 * float(sigma_max)

def _fit_line_r2(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """Ordinary least-squares fit y ~ c + m*x; returns (slope m, R^2)."""
    A = np.column_stack([np.ones_like(x), x])
    c0, m = np.linalg.lstsq(A, y, rcond=None)[0]
    ss_res = np.sum((y - (c0 + m*x))**2)
    ss_tot = np.sum((y - y.mean())**2)
    return float(m), float(1.0 - ss_res/(ss_tot + 1e-12))


def _lambda_key(lam: float) -> str:
    """Dict key for a lambda value: one decimal when that is exact, else its full repr."""
    short = f"{lam:.1f}"
    return short if float(short) == float(lam) else repr(float(lam))


def cc_translate_and_fit(
    E0: np.ndarray,
    cfg: CCConfig,
    L: int
) -> Dict[str, Any]:
    """
    Main routine: builds S+ from E0, computes V and |∇V|, radial power-law fits,
    lensing α(b) vs 1/b slopes for given λ sweep. Returns dict of metrics.

    Args:
        E0:  2D input map.
        cfg: pipeline settings.
        L:   grid size used to build the (2L-1)x(2L-1) softened kernel and to
             set rmax = cfg.fit_window_max_fracL * L.
             NOTE(review): presumably E0 is LxL; confirm with callers.

    Returns a dict with:
        "cc_meta": mask metadata from _build_Splus,
        "radial":  rmin, rmax, nbins and the slope/R^2 of the power-law fits
                   to the potential ("s_phi", "r2_phi") and to its gradient
                   magnitude ("s_grad", "r2_grad"),
        "lensing": one {"alpha_slope", "alpha_r2"} entry per lambda. Keys are
                   the lambda formatted with one decimal ("0.2", "1.0") when
                   that is exact, otherwise its full repr, so distinct lambdas
                   never overwrite each other. Entries are NaN when fewer
                   than 6 impact parameters fit inside the array.
    """
    # 1) CC response, clamped to positive curvature to avoid flooding the mask
    resp = np.maximum(_cc_response(E0, cfg), 0.0)

    # 2) S+ mask
    Splus, s_meta = _build_Splus(resp, cfg)

    # 3) Potential via aperiodic convolution with softened 1/r kernel
    eps = _parse_eps(cfg)
    V = _fft_convolve2d_aperiodic(Splus, _softened_kernel(L, eps))
    V = np.maximum(V, 1e-12)  # keep positive

    # 4) Field magnitude, floored so the log-log fit never sees zero
    G = np.maximum(_gradient_mag(V), 1e-16)

    # 5) Radial profiles & fits
    rmin = _parse_rmin(cfg.fit_window_min, max(cfg.sigma_list))
    rmax = cfg.fit_window_max_fracL * L
    rbins = cfg.radial_bins
    use_counts = cfg.regression_weights == "counts"

    fits = {}
    for tag, field in (("phi", V), ("grad", G)):
        r, prof, cnt = _radial_profile(
            field, nbins=rbins, scheme=cfg.radial_bins_scheme, rmin=rmin, rmax=rmax
        )
        fits[tag] = _fit_powerlaw(r, prof, weights=cnt if use_counts else None)
    s_phi, r2_phi = fits["phi"]
    s_grad, r2_grad = fits["grad"]

    # 6) Optics: lensing α(b) vs 1/b for lambda sweep
    bvals = np.linspace(cfg.lensing_b_min, cfg.lensing_b_max, cfg.lensing_b_n)
    lensing = {}
    for lam in cfg.lambda_sweep:
        alpha, invb = _alpha_small_angle(1.0 + lam * Splus, bvals)
        if len(alpha) >= 6:
            # linear fit alpha ~ m*(1/b) + c
            slope, r2 = _fit_line_r2(invb, alpha)
        else:
            slope, r2 = np.nan, np.nan
        lensing[_lambda_key(lam)] = {"alpha_slope": slope, "alpha_r2": r2}

    return {
        "cc_meta": s_meta,
        "radial": {
            "rmin": float(rmin), "rmax": float(rmax), "nbins": int(rbins),
            "s_phi": float(s_phi), "r2_phi": float(r2_phi),
            "s_grad": float(s_grad), "r2_grad": float(r2_grad)
        },
        "lensing": lensing
    }
